import argparse
import os
import random
import time

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=os.path.splitext(os.path.basename(__file__))[0])
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--torch_deterministic', type=lambda x: str(x).lower() == 'true', default=True)
    parser.add_argument('--cuda', type=lambda x: str(x).lower() == 'true', default=True)
    parser.add_argument('--env_id', type=str, default='Ant-v4')
    parser.add_argument('--total_time_steps', type=int, default=int(1e7))
    parser.add_argument('--learning_rate', type=float, default=3e-4)
    parser.add_argument('--num_envs', type=int, default=8)
    parser.add_argument('--num_steps', type=int, default=256)
    parser.add_argument('--anneal_lr', type=lambda x: str(x).lower() == 'true', default=True)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--gae_lambda', type=float, default=0.95)
    parser.add_argument('--num_mini_batches', type=int, default=4)
    parser.add_argument('--update_epochs', type=int, default=10)
    parser.add_argument('--norm_adv', type=lambda x: str(x).lower() == 'true', default=True)
    parser.add_argument('--clip_value_loss', type=lambda x: str(x).lower() == 'true', default=True)
    parser.add_argument('--c_1', type=float, default=0.5)
    parser.add_argument('--c_2', type=float, default=0.0)
    parser.add_argument('--max_grad_norm', type=float, default=0.5)
    parser.add_argument('--kld_max', type=float, default=0.02)
    a = parser.parse_args()
    a.batch_size = int(a.num_envs * a.num_steps)
    a.minibatch_size = int(a.batch_size // a.num_mini_batches)
    return a


def make_env(env_id, gamma):
    def thunk():
        env = gym.make(env_id)
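        # Standard preprocessing stack: flatten and normalize observations (clipped to [-10, 10]),
        # clip actions to the action-space bounds, and normalize and clip rewards.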
        env = gym.wrappers.FlattenObservation(env)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = gym.wrappers.ClipAction(env)
        env = gym.wrappers.NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(env, lambda o: np.clip(o, -10, 10))
        env = gym.wrappers.NormalizeReward(env, gamma=gamma)
        env = gym.wrappers.TransformReward(env, lambda r: float(np.clip(r, -10, 10)))
        return env
    return thunk


def layer_init(layer, s=np.sqrt(2), bias_const=0.0):
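    # Orthogonal weight initialization with gain s and a constant bias (the usual PPO-style init).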
    torch.nn.init.orthogonal_(layer.weight, s)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    def __init__(self, e):
        super().__init__()
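        # Separate critic and actor MLPs, each with two 64-unit tanh hidden layers; the critic head
        # is initialized with gain 1.0, the actor mean head with gain 0.01, and the Gaussian policy
        # uses a state-independent learnable log-std.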
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(e.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), s=1.0),
        )
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(np.array(e.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, np.array(e.single_action_space.shape).prod()), s=0.01),
        )
        self.actor_log_std = nn.Parameter(torch.zeros(1, np.array(e.single_action_space.shape).prod()))

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, a=None, show_all=False):
        action_mean = self.actor_mean(x)
        action_log_std = self.actor_log_std.expand_as(action_mean)
        action_std = torch.exp(action_log_std)
        probs = Normal(action_mean, action_std)
        if a is None:
            a = probs.sample()
        if show_all:
            return a, probs.log_prob(a).sum(1), probs.entropy().sum(1), self.critic(x), probs
        return a, probs.log_prob(a).sum(1), probs.entropy().sum(1), self.critic(x)


def compute_kld(mu_1, sigma_1, mu_2, sigma_2):
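    # Element-wise closed-form KL divergence KL(N(mu_1, sigma_1^2) || N(mu_2, sigma_2^2)),
    # equivalent to: log(sigma_2 / sigma_1) + (sigma_1^2 + (mu_1 - mu_2)^2) / (2 * sigma_2^2) - 1/2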
    return torch.log(sigma_2 / sigma_1) + ((mu_1 - mu_2) ** 2 + (sigma_1 ** 2 - sigma_2 ** 2)) / (2 * sigma_2 ** 2)


def main(env_id, seed):
    args = get_args()
    args.env_id = env_id
    args.seed = seed
    run_name = f'spo_epoch_{args.update_epochs}_seed_{args.seed}'

    # TensorBoard logging of hyperparameters and training metrics
    path_string = str(args.env_id) + '/' + run_name
    writer = SummaryWriter(path_string)
    writer.add_text(
        'Hyperparameter',
        '|param|value|\n|-|-|\n%s' % ('\n'.join([f'|{key}|{value}|' for key, value in vars(args).items()])),
    )

    # Seed everything for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')

    # Initialize the vectorized environments
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.gamma) for _ in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Box), 'only continuous action space is supported'

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # Rollout storage buffers
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    log_probs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)
    mean = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    std = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)

    # Start collecting rollout data
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
    num_updates = args.total_time_steps // args.batch_size

    for update in tqdm(range(1, num_updates + 1)):

        # Linearly anneal the learning rate
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lr_now = frac * args.learning_rate
            optimizer.param_groups[0]['lr'] = lr_now

        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # Sample an action from the current (old) policy and record its log-probability and value
            with torch.no_grad():
                action, log_prob, _, value, mean_std = agent.get_action_and_value(next_obs, show_all=True)
                values[step] = value.flatten()
            actions[step] = action
            log_probs[step] = log_prob

            # Store the Gaussian mean and std: shape (num_steps, num_envs, action_dim)
            mean[step] = mean_std.loc
            std[step] = mean_std.scale

            # Step the environments
            next_obs, reward, terminations, truncations, info = envs.step(action.cpu().numpy())
            done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
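            # 'final_info' only appears on steps where at least one sub-environment finished an
            # episode; RecordEpisodeStatistics stores the episodic return under item['episode']['r'].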

            if 'final_info' not in info:
                continue

            for item in info['final_info']:
                if item is None:
                    continue
                # print(f"global_step={global_step}, episodic_return={float(item['episode']['r'][0])}")
                writer.add_scalar('charts/episodic_return', item['episode']['r'][0], global_step)

        # Estimate advantages with Generalized Advantage Estimation (GAE), working backwards over the rollout
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            last_gae_lam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    next_non_terminal = 1.0 - next_done
                    next_values = next_value
                else:
                    next_non_terminal = 1.0 - dones[t + 1]
                    next_values = values[t + 1]
                delta = rewards[t] + args.gamma * next_values * next_non_terminal - values[t]
                advantages[t] = last_gae_lam = delta + args.gamma * args.gae_lambda * next_non_terminal * last_gae_lam
            returns = advantages + values

        # ------------------------------- Rollout collection finished; start the update phase ------------------------------- #
        # Flatten the rollout into a single batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_log_probs = log_probs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Flatten the stored policy means and stds for the batch
        b_mean = mean.reshape(args.batch_size, -1)
        b_std = std.reshape(args.batch_size, -1)

        # Update the policy and value networks
        b_index = np.arange(args.batch_size)
        for epoch in range(1, args.update_epochs + 1):
            np.random.shuffle(b_index)
            t = 0
            for start in range(0, args.batch_size, args.minibatch_size):
                t += 1
                end = start + args.minibatch_size
                mb_index = b_index[start:end]

                # Outputs of the current policy and value networks for the minibatch
                _, new_log_prob, entropy, new_value, new_mean_std = agent.get_action_and_value(b_obs[mb_index],
                                                                                               b_actions[mb_index],
                                                                                               show_all=True)
                # Per-sample KL divergence between the old and current policies (summed over action dimensions)
                new_mean = new_mean_std.loc.reshape(args.minibatch_size, -1)
                new_std = new_mean_std.scale.reshape(args.minibatch_size, -1)
                d = compute_kld(b_mean[mb_index], b_std[mb_index], new_mean, new_std).sum(1)

                writer.add_scalar('charts/average_kld', d.mean(), global_step)
                writer.add_scalar('others/min_kld', d.min(), global_step)
                writer.add_scalar('others/max_kld', d.max(), global_step)

                log_ratio = new_log_prob - b_log_probs[mb_index]
                ratios = log_ratio.exp()
                mb_advantages = b_advantages[mb_index]

                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-12)
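                # KL-based surrogate (the 'spo' update this script logs under): instead of PPO's
                # ratio clipping, the per-sample KL to the old policy is clipped at kld_max and the
                # resulting factor rescales the vanilla policy-gradient term, with the rescaling
                # depending on the sign of the advantage.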

                # Policy (actor) loss
                if epoch == 1 and t == 1:
                    pg_loss = (-mb_advantages * ratios).mean()
                else:
                    # d_clip
                    d_clip = torch.clamp(input=d, min=0, max=args.kld_max)
                    # d_clip / d
                    ratio = d_clip / (d + 1e-12)
                    # sign_a
                    sign_a = torch.sign(mb_advantages)
                    # (d_clip / d + sign_a - 1) * sign_a
                    result = (ratio + sign_a - 1) * sign_a
                    pg_loss = (-mb_advantages * ratios * result).mean()

                # Value (critic) loss, with optional PPO-style clipping (fixed clip range of 0.2)
                new_value = new_value.view(-1)
                if args.clip_value_loss:
                    v_loss_un_clipped = (new_value - b_returns[mb_index]) ** 2
                    v_clipped = b_values[mb_index] + torch.clamp(
                        new_value - b_values[mb_index],
                        -0.2,
                        0.2,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_index]) ** 2
                    v_loss_max = torch.max(v_loss_un_clipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((new_value - b_returns[mb_index]) ** 2).mean()

                # Entropy of the policy's action distribution
                entropy_loss = entropy.mean()

                # Total loss
                loss = pg_loss + v_loss * args.c_1 - entropy_loss * args.c_2

                # Log the training losses
                writer.add_scalar('losses/policy_loss', pg_loss.item(), global_step)
                writer.add_scalar('losses/value_loss', v_loss.item(), global_step)
                writer.add_scalar('losses/entropy', entropy_loss.item(), global_step)

                # Gradient step with global-norm clipping
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

        # Explained variance of the value function's predictions
        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
        writer.add_scalar('others/explained_variance', explained_var, global_step)

        # Log the learning rate and throughput
        writer.add_scalar('charts/learning_rate', optimizer.param_groups[0]['lr'], global_step)
        writer.add_scalar('charts/SPS', int(global_step / (time.time() - start_time)), global_step)
        # print('SPS:', int(global_step / (time.time() - start_time)))

    envs.close()
    writer.close()


def run():
    for env_id in [
        'Humanoid-v4',
        'Ant-v4',
        'HalfCheetah-v4',
        'Hopper-v4',
        'HumanoidStandup-v4',
        'InvertedDoublePendulum-v4'
    ]:
        for seed in range(1, 4):
            print(env_id, 'seed:', seed)
            main(env_id, seed)


if __name__ == '__main__':
    run()
